查看原文
其他

基于OpenCV实现手写体数字训练与识别

gloomyfish OpenCV学堂 2019-03-29

OpenCV实现手写体数字训练与识别

机器学习(ML)是OpenCV模块之一,对于常见的数字识别与英文字母识别都可以做到很高的识别率,完成这类应用的主要思想与方法是首选对训练图像数据完成预处理与特征提取,根据特征数据组成符合OpenCV要求的训练数据集与标记集,然后通过机器学习的KNN、SVM、ANN等方法完成训练,训练结束之后保存训练结果,对待检测的图像完成分割、二值化、ROI等操作之后,加载训练好的分类数据,就可以预言未知分类。

一:数据集

这里使用的数据集是mnist 手写体数字数据集、关于数据集的具体说明如下:

数据集名称说明
train-images-idx3-ubyte.gz训练图像28x28大小,6万张
train-labels-idx1-ubyte.gz每张图像的数字标记,6万条
t10k-images-idx3-ubyte.gz测试数据集、1万张图像28x28
t10k-labels-idx1-ubyte.gz测试数据集标记,表示图像数字

上述数据集数据组成内部结构,图像是以灰度每个字节表示一个像素点的灰度值,图像的总数、宽与高的大小从开始位置读取,说明如下:

开始移位类型描述
00004字节int类型0x00000803(2051)魔数
00044字节int类型60000图像数目
00084字节int类型28图像高度
000124字节int类型28图像宽度

标记部分数据组成如下:

开始移位类型描述
00004字节int类型0x00000801(2049)魔数
00044字节int类型60000标记数目
00081字节ubyte??对应图像数字
00091字节ubyte??对应图像数字


读取标记数据集读取图像数据集代码如下:

  1. Mat readImages(int opt) {

  2.    int idx = 0;

  3.    ifstream file;

  4.    Mat img;

  5.    if (opt == 0)

  6.    {

  7.        cout << "\n Training...";

  8.        file.open("D:/vcprojects/images/mnist/train-images.idx3-ubyte", ios::binary);

  9.    }

  10.    else

  11.    {

  12.        cout << "\n Test...";

  13.        file.open("D:/vcprojects/images/mnist/t10k-images.idx3-ubyte", ios::binary);

  14.    }

  15.    // check file

  16.    if (!file.is_open())

  17.    {

  18.        cout << "\n File Not Found!";

  19.        return img;

  20.    }

  21.    /*

  22.    byte 0 - 3 : Magic Number(Not to be used)

  23.    byte 4 - 7 : Total number of images in the dataset

  24.    byte 8 - 11 : rows of each image in the dataset

  25.    byte 12 - 15 : cols of each image in the dataset

  26.    */

  27.    int magic_number = 0;

  28.    int number_of_images = 0;

  29.    int height = 0;

  30.    int width = 0;

  31.    file.read((char*)&magic_number, sizeof(magic_number));

  32.    magic_number = reverseDigit(magic_number);

  33.    file.read((char*)&number_of_images, sizeof(number_of_images));

  34.    number_of_images = reverseDigit(number_of_images);

  35.    file.read((char*)&height, sizeof(height));

  36.    height = reverseDigit(height);

  37.    file.read((char*)&width, sizeof(width));

  38.    width = reverseDigit(width);

  39.    Mat train_images = Mat(number_of_images, height*width, CV_8UC1);

  40.    cout << "\n No. of images:" << number_of_images <<endl;

  41.    Mat digitImg = Mat::zeros(height, width, CV_8UC1);

  42.    for (int i = 0; i < number_of_images; i++) {

  43.        int index = 0;  

  44.        for (int r = 0; r<height; ++r) {

  45.            for (int c = 0; c<width; ++c) {

  46.                unsigned char temp = 0;

  47.                file.read((char*)&temp, sizeof(temp));

  48.                index = r*width + c;

  49.                train_images.at<uchar>(i, index) = (int)temp;

  50.                digitImg.at<uchar>(r, c) = (int)temp;

  51.            }

  52.        }

  53.        if (i < 100) {

  54.            imwrite(format("D:/vcprojects/images/mnist/images/digit_%d.png", i), digitImg);

  55.        }

  56.    }

  57.    train_images.convertTo(train_images, CV_32FC1);

  58.    return train_images;

  59. }

  1. Mat readLabels(int opt) {

  2.    int idx = 0;

  3.    ifstream file;

  4.    Mat img;

  5.    if (opt == 0)

  6.    {

  7.        cout << "\n Training...";

  8.        file.open("D:/vcprojects/images/mnist/train-labels.idx1-ubyte");

  9.    }

  10.    else

  11.    {

  12.        cout << "\n Test...";

  13.        file.open("D:/vcprojects/images/mnist/t10k-labels.idx1-ubyte");

  14.    }

  15.    // check file

  16.    if (!file.is_open())

  17.    {

  18.        cout << "\n File Not Found!";

  19.        return img;

  20.    }

  21.    /*

  22.    byte 0 - 3 : Magic Number(Not to be used)

  23.    byte 4 - 7 : Total number of labels in the dataset

  24.    */

  25.    int magic_number = 0;

  26.    int number_of_labels = 0;

  27.    file.read((char*)&magic_number, sizeof(magic_number));

  28.    magic_number = reverseDigit(magic_number);

  29.    file.read((char*)&number_of_labels, sizeof(number_of_labels));

  30.    number_of_labels = reverseDigit(number_of_labels);

  31.    cout << "\n No. of labels:" << number_of_labels << endl;

  32.    Mat labels = Mat(number_of_labels, 1, CV_8UC1);

  33.    for (long int i = 0; i<number_of_labels; ++i)

  34.    {

  35.        unsigned char temp = 0;

  36.        file.read((char*)&temp, sizeof(temp));

  37.        //printf("temp : %d\n ", temp);

  38.        labels.at<uchar>(i, 0) = temp;

  39.    }

  40.    labels.convertTo(labels, CV_32SC1);

  41.    return labels;

  42. }

二:训练与测试

对上述数据集,我们不使用提取特征方式,而是采用纯像素数据作为输入,分别使用KNN与SVM对数据集进行训练与测试,比较他们最终的识别率。

KNN方式

KNN是最简单的机器学习方法、主要是计算目标与模型之间的空间向量距离得到最终预测分类结果。训练的代码如下:

  1. Ptr<ml::KNearest> knn = ml::KNearest::create();

  2. knn->setDefaultK(5);

  3. knn->setIsClassifier(true);

  4. Ptr<ml::TrainData> tdata = ml::TrainData::create(train_images, ml::ROW_SAMPLE, train_labels);

  5. knn->train(tdata);

测试代码如下:

  1. void testMnist() {

  2.    //Ptr<ml::SVM> svm = Algorithm::load<ml::SVM>("D:/vcprojects/images/mnist/knn_knowledge.yml"); // SVM-POLY - 98%

  3.    Ptr<ml::KNearest> knn = Algorithm::load<ml::KNearest>("D:/vcprojects/images/mnist/knn_knowledge.yml"); // KNN - 97%

  4.    Mat train_images = readImages(1);

  5.    Mat train_labels = readLabels(1);

  6.    printf("\n read mnist test dataset successfully...\n");

  7.    float total = train_images.rows;

  8.    float correct = 0;

  9.    Rect rect;

  10.    rect.x = 0;

  11.    rect.height = 1;

  12.    rect.width = (28 * 28);

  13.    for (int i = 0; i < total; i++) {

  14.        int actual = train_labels.at<int>(i);

  15.        rect.y = i;

  16.        Mat oneImage = train_images(rect);

  17.        //int digit = svm->predict(oneImage);

  18.        Mat result;

  19.        float predicted = knn->predict(oneImage, result);

  20.        int digit = static_cast<int>(predicted);

  21.        if (digit == actual) {

  22.            correct++;

  23.        }

  24.    }

  25.    printf("\n recognize rate : %.2f \n", correct / total);

  26. }

上述KNN基于纯像素方式的数据训练与测试准确率高达97%。

SVM方式

SVM的全称是支掌向量机,本来是用来对数据进行二分类的预测与分析、后来扩展到可以对数据进行回归与多分类预测与分析,主要是把数据映射到高维数据空间、把靠近高维数据的部分称为支掌向量(SV)。SVM根据使用的核不同、参数不同,可以得到不同的分类与预测结果、所以在OpenCV中使用SVM做分类的时候,尽量推荐大家使用train_auto方法来训练、但是trainauto运行时间一般都会比较久,有时候可能长达数天。 SVM的训练代码如下:

  1. // 创建与初始化

  2. Ptr<cv::ml::SVM> svm = ml::SVM::create();

  3. svm->setType(ml::SVM::C_SVC);

  4. svm->setKernel(ml::SVM::POLY);

  5. svm->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, 1e-6));

  6. svm->setGamma(3);

  7. svm->setDegree(3);

  8. // SVM训练mnist数据集分类

  9. svm->train(train_images, ml::ROW_SAMPLE, train_labels);

关于参数设置部分、更加详细的可以参加OpenCV机器学习模块API说明,影响最终识别率的因素有很多,其中SVM训练收敛终止条件的最终循环数大小跟运行时间训练时间有关系,实验证明1e4/1e3的效果都比较好,我采用1e3,对测试数数据做预测、准确率达到98%。其测试代码跟上面KNN的极其类似。这里不再给出。

三:应用

训练好的数据保存在本地,初始化加载,使用对象的识别方法就可以预测分类、进行对象识别。当然这么做,还需要对输入的手写数字图像进行二值化、分割、调整等预处理之后才可以传入进行预测。完整的步骤如下:

以下是两个测试图像识别结果:

演示一截屏:


演示二截屏:

注意点:

最终要把图像Mat对象转换为CV_32FC1的灰度,使用resharp转换为行模式,调用svm->predict(Mat)即可返回预测结果。

善学者,假人之长以补其短


关注【OpenCV学堂】

长按或者扫码下面二维码即可关注

+OpenCV学习群 376281510

进群暗号:OpenCV


    您可能也对以下帖子感兴趣

    文章有问题?点此查看未经处理的缓存